[reland][ROCm] preshuffled weight mm #2044
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2044
Note: links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@mxz297 @jerryzh168 please re-review and kick off CI, thanks.
@jeffdaily "test-mps-ops" still seems to be failing to compile with
I wonder if we should just guard the whole source file under #if USE_ROCM |
Done. |
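For readers outside the thread: besides the in-file #if USE_ROCM guard, the same effect can be achieved at build time by only compiling ROCm-only sources on ROCm builds. A minimal sketch, assuming the extension is built via torch.utils.cpp_extension (which exposes the detected ROCm root as ROCM_HOME); the file paths are illustrative, not torchao's actual setup.py:

```python
# Hypothetical setup.py fragment: skip ROCm-only translation units entirely
# on non-ROCm builds, as an alternative (or complement) to an in-file
# `#if defined(USE_ROCM)` guard.
from torch.utils.cpp_extension import ROCM_HOME  # None when no ROCm toolchain

sources = ["torchao/csrc/tensor_core_tiled_layout.cu"]  # illustrative path
if ROCM_HOME is not None:
    sources.append("torchao/csrc/swizzle.cpp")  # ROCm-only source
```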
@mxz297 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
* [ROCm][experimental] pre-shuffle weights
* add custom gemm op
* pass through swizzled
* copy paste bug causing extra matmul to execute
* correct transpose and permute logic
* swizzle.cpp is rocm-only, remove #ifndef USE_ROCM
* transpose is shallow, don't unswizzle/swizzle
* add fp8 swizzle
* remove print statement
* setup.py missing check for vec ext
* remove merge mistake
* conditionalize building sparse marlin for hip
* ruff format
* ruff check --fix
* protect swizzle.cpp inside USE_ROCM
* patch from @mxz297
Is this not fixed?
@jeffdaily @jerryzh168 This is very strange. Previously, this error only showed up on non-AMD platforms, so we added a commit that guards the whole source file under #if USE_ROCM, and we indeed no longer saw the failure. I also saw clean merge signals before doing the merge, so I am a little surprised by this. @jeffdaily, are you able to repro these ROCm failures somehow?
The error appears in the internal diff as well: https://www.internalfb.com/diff/D73052566. I think we should revert for now? Does this error not appear in the original PR/diff?
@jerryzh168 replied in the internal diff, but it seems like some other failure, likely caused by some other diff.
I am seeing the same error in the wheel build.
Looking at our code, we have #if defined(USE_ROCM) in tensor_core_tiled_layout.cu. Is the HIP include here not gated correctly?
That looks gated correctly. The CI build is missing -I/opt/rocm for some reason. The header files are there, but the flag is missing.
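A sketch of the kind of fix this implies, again assuming the torch.utils.cpp_extension build path; the actual CI/build change may differ:

```python
# Hypothetical fragment: put the ROCm headers on the compile line so gated
# HIP includes (e.g. <hip/hip_runtime.h>) in tensor_core_tiled_layout.cu
# resolve. ROCM_HOME is typically /opt/rocm when detected.
import os
from torch.utils.cpp_extension import ROCM_HOME

include_dirs = []
if ROCM_HOME is not None:
    include_dirs.append(os.path.join(ROCM_HOME, "include"))
# include_dirs would then be passed to the extension's include_dirs argument
```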
__all__ = [
    "dtypes",
    "autoquant",
    "optim",
    "quantize_",
    "swizzle",
Why is this added to the top level? Should this be in prototype for now?
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
We don't want to create a new folder under torchao for this tensor/op, I think.
Where do you recommend it go?
Is this a prototype? We can add it to torchao/prototype for now.
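If it does move under prototype, usage would presumably look like the following (hypothetical path, shown only to illustrate the suggestion):

```python
from torchao.prototype import swizzle  # hypothetical location
```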
This reverts commit 2266451.
Adds a SwizzleTensor subclass that wraps a Tensor and reorders its contents to be suitable for HIPBLASLT_ORDER_COL16_4R8. SwizzleTensor intercepts torch.mm and replaces it with custom calls to hipblaslt.
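For illustration, a minimal sketch of the interception pattern this describes, built on a __torch_function__ wrapper subclass. The swizzle itself and the hipblaslt-backed op are stand-ins, not the actual torchao implementation:

```python
import torch


class SwizzleTensor(torch.Tensor):
    """Sketch: wraps a weight whose physical contents are reordered."""

    @staticmethod
    def __new__(cls, original: torch.Tensor):
        # The wrapper reports the logical (unswizzled) shape and dtype.
        return torch.Tensor._make_wrapper_subclass(
            cls, original.shape, dtype=original.dtype, device=original.device
        )

    def __init__(self, original: torch.Tensor):
        # Real code would reorder into HIPBLASLT_ORDER_COL16_4R8 here; a
        # contiguous copy stands in for the swizzled buffer.
        self.swizzled = original.contiguous()

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.mm and isinstance(args[1], SwizzleTensor):
            a, b = args
            # Real code dispatches to a custom hipblaslt call that
            # understands the swizzled layout; plain mm stands in.
            return torch.mm(a, b.swizzled)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


# Hypothetical usage: mm on the wrapped weight routes through the branch above.
x, w = torch.randn(4, 8), torch.randn(8, 8)
assert torch.mm(x, SwizzleTensor(w)).shape == (4, 8)
```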